"""
Here are some features that are based on the absenteeism data set.
They are just supposed to be illustrative of the kind of features one might have.

Note (1): we use check_output to warn us if the output is not what we expect; this is
used in both the offline ETL and the online webservice. Use `check_output` to help
encode your expectations about the output of your functions and catch bugs early!

Note (2): we can tag the `aggregation` features with whatever key value pair makes sense
for us to discern/identify that we should not compute these features in an online setting.
"""

import pandas as pd
import pandera as pa

from hamilton.function_modifiers import check_output, extract_columns, tag


@tag(inject_at_inference_time="True")
@check_output(range=(20.0, 60.0), data_type=float)
def age_mean(age: pd.Series) -> float:
    """Average of age"""
    return age.mean()


age_zero_mean_schema = pa.SeriesSchema(
    float,
    checks=[
        pa.Check.in_range(-120.0, 120.0),
    ],
    nullable=False,
)


@check_output(schema=age_zero_mean_schema)
def age_zero_mean(age: pd.Series, age_mean: float) -> pd.Series:
    """Zero mean of age"""
    return age - age_mean


age_std_dev_schema = pa.SeriesSchema(
    float,
    checks=[
        pa.Check.in_range(0.0, 40.0),
    ],
)


@tag(inject_at_inference_time="True")
@check_output(range=(0.0, 40.0), data_type=float)
def age_std_dev(age: pd.Series) -> float:
    """Standard deviation of age."""
    return age.std()


age_zero_mean_unit_variance_schema = pa.SeriesSchema(
    float,
    checks=[
        pa.Check.in_range(-4.0, 4.0),
    ],
    nullable=False,
)


@check_output(schema=age_zero_mean_unit_variance_schema)
def age_zero_mean_unit_variance(age_zero_mean: pd.Series, age_std_dev: float) -> pd.Series:
    """Zero mean unit variance value of age"""
    return age_zero_mean / age_std_dev


@extract_columns("seasons_1", "seasons_2", "seasons_3", "seasons_4", fill_with=0)
def seasons_encoded(seasons: pd.Series) -> pd.DataFrame:
    """One hot encodes seasons into 4 dimensions:
    1 - first season
    2 - second season
    3 - third season
    4 - fourth season
    """
    return pd.get_dummies(seasons, prefix="seasons")


@extract_columns(
    "day_of_the_week_2",
    "day_of_the_week_3",
    "day_of_the_week_4",
    "day_of_the_week_5",
    "day_of_the_week_6",
    fill_with=0,
)
def day_of_week_encoded(day_of_the_week: pd.Series) -> pd.DataFrame:
    """One hot encodes day of week into five dimensions -- Saturday & Sunday weren't present.
    1 - Sunday, 2 - Monday, 3 - Tuesday, 4 - Wednesday, 5 - Thursday, 6 - Friday, 7 - Saturday.
    """
    return pd.get_dummies(day_of_the_week, prefix="day_of_the_week")


has_children_schema = pa.SeriesSchema(
    int,
    checks=[
        pa.Check.isin([0, 1]),
    ],
    nullable=False,
)


@check_output(schema=has_children_schema)
def has_children(son: pd.Series) -> pd.Series:
    """Single variable that says whether someone has any children or not."""
    return (son > 0).astype(int)


has_pet_schema = pa.SeriesSchema(
    int,
    checks=[
        pa.Check.isin([0, 1]),
    ],
    nullable=False,
)


@check_output(schema=has_pet_schema)
def has_pet(pet: pd.Series) -> pd.Series:
    """Single variable that says whether someone has any pets or not."""
    return (pet > 0).astype(int)


is_summer_schema = pa.SeriesSchema(
    int,
    checks=[
        pa.Check.isin([0, 1]),
    ],
    nullable=False,
)


@check_output(schema=is_summer_schema)
def is_summer(month_of_absence: pd.Series) -> pd.Series:
    """Is it summer in Brazil? i.e. months of December, January, February."""
    return month_of_absence.isin([1, 2, 12]).astype(int)
